#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# PisiLinux initramfs creator
#

import os
import sys
import glob
import stat
import shutil
import tempfile
import subprocess

from optparse import OptionParser

# Global run-time settings shared by every class in this script.  The
# command-line handling at the bottom of the file overwrites most of them.
config = {
    "rootDir": "/",
    "tmpDir": "",
    "debug": False,
    "dryrun": False,
    "destDir": "/boot",
    "blackList": ["pktcdvd", "floppy"],
    "kernelType": "kernel",
    "initramfsConf": "/etc/initramfs.conf",
    "kernelVersion": "",
    # KernelModule.__init__ reads the three keys below; giving them the same
    # defaults as the command-line options avoids a KeyError when the class
    # is used without going through the script entry point.
    "networkModule": False,
    "networkModuleBasic": False,
    "excludeDRM": False,
}

def loadFile(_file):
    """Return the meaningful lines of the text file _file as a list.

    Leading whitespace and the trailing newline are removed from every line;
    lines that are then empty or start with "#" are dropped.

    An unreadable or missing file yields an empty list instead of an error.
    """
    try:
        # The with-block closes the handle even when reading fails half-way.
        with open(_file) as handle:
            lines = [raw.lstrip().rstrip("\n") for raw in handle]
    except (IOError, OSError):
        return []

    return [line for line in lines if line and not line.startswith("#")]

def writeFile(_file, data):
    """Create or truncate _file and write the string data to it."""
    # The context manager closes the file even if write() raises.
    with open(_file, "w") as handle:
        handle.write(data)

def mkdir(_dir):
    """Create directory _dir, including any missing parent directories.

    Thin wrapper around os.makedirs, so os.makedirs' own errors (for example
    when _dir already exists) propagate to the caller unchanged.
    """
    os.makedirs(_dir)

def dohardlink(destination, source):
    """Create a hard link named source that refers to the file destination.

    Note the argument order: "destination" is the existing file the link
    points at, and "source" is the path of the new link, i.e. the call is
    os.link(destination, source).
    """
    existing_file, new_link = destination, source
    os.link(existing_file, new_link)

def dosymlink(destination, source):
    """Create a symbolic link at path source whose target is destination.

    The target string is stored as given (it may be relative), i.e. the call
    is os.symlink(destination, source).
    """
    link_target, link_path = destination, source
    os.symlink(link_target, link_path)

def mknod(nodfile, nodtype, major, minor, perms=0o666):
    """Create a device node at path nodfile.

    nodtype follows the mknod(1) convention: "c" creates a character device,
    any other value creates a block device.  major and minor are the device
    numbers and perms the permission bits of the new node.
    """
    # 0o666 is the same value as the legacy literal 0666 but is also valid
    # syntax on Python 3 (accepted since Python 2.6).
    devtype = stat.S_IFCHR if nodtype == "c" else stat.S_IFBLK
    os.mknod(nodfile, perms | devtype, os.makedev(major, minor))

def copy(source, destination):
    """Copy source to destination, creating the destination's parent first.

    The parent is os.path.dirname(destination), so a destination written with
    a trailing slash names a directory that is created on demand.  An IOError
    during the copy is reported as a warning rather than raised.
    """
    parent = os.path.dirname(destination)
    parent_is_missing = not os.path.isdir(parent)
    if parent_is_missing:
        mkdir(parent)

    try:
        shutil.copy2(source, destination)
    except IOError:
        warning = "Could not find %s" % source
        printWarn(warning)

def touch(_file):
    """Create _file as an empty file, or refresh its timestamps if it exists."""
    if not os.path.exists(_file):
        # Opening for writing and closing straight away leaves an empty file.
        open(_file, 'w').close()
        return

    # A times argument of None makes os.utime use the current time.
    os.utime(_file, None)

def capture(*cmd):
    """Run cmd through the shell and return its (stdout, stderr) output.

    The positional arguments are handed to subprocess.Popen as one sequence
    with shell=True.  The exit status is not returned.
    """
    pipe = subprocess.PIPE
    process = subprocess.Popen(cmd, shell=True, stdout=pipe, stderr=pipe)
    output = process.communicate()
    return output

def run(*cmd):
    """Run cmd through the shell with all output discarded.

    The positional arguments are handed to subprocess.call as one sequence
    with shell=True.  Returns the exit status of the command.
    """
    # The with-block closes the null-device handle once the command has
    # finished instead of leaving it open for the life of the process.
    with open(os.devnull, "w") as devnull:
        return subprocess.call(cmd, shell=True, stdout=devnull, stderr=devnull)

def printFail(msg):
    """Print msg as an error, clean up the temporary directory and exit.

    Relies on the module-level "tempdir" object that the script entry point
    creates.  Terminates the process with exit status 1.
    """
    # Parenthesised single-argument form prints the same text on Python 2
    # and is also valid syntax on Python 3.
    print("ERROR: %s" % msg)

    tempdir.cleanup()
    sys.exit(1)

def printWarn(msg):
    """Print msg to standard output prefixed with "WARNING: "."""
    # Parenthesised single-argument form prints the same text on Python 2
    # and is also valid syntax on Python 3.
    print("WARNING: %s" % msg)

def setKernelVersion(Version=""):
    """Store the kernel version to build for in config["kernelVersion"].

    A non-empty Version is used as is.  Otherwise the version is read from
    /etc/kernel/<kernelType>; if that yields nothing either, the release of
    the running kernel (os.uname()[2]) is used and a notice is printed.
    """
    kver = Version
    if kver == "":
        kver = "".join(loadFile("/etc/kernel/%s" % config["kernelType"]))

    # FIXME: we should do this an option and in a try/except
    if kver == "":
        # Parenthesised single-argument form: identical output on Python 2,
        # valid syntax on Python 3.
        print("could not find version of %s, autodetecting" % config["kernelType"])
        kver = os.uname()[2]

    config["kernelVersion"] = kver

class BaseSystem:
    """Populates the temporary directory with the static part of the image.

    That covers the directory skeleton, busybox and the init scripts, device
    nodes, busybox applet links, the initramfs configuration file and the
    optional RAID, LVM, suspend and modprobe blacklist files.
    """

    def __init__(self):
        self.kver = config["kernelVersion"]
        self.tmpDir = config["tmpDir"]

        # Directories created inside the image, relative to its root.
        self.baseDirs = ["bin", "sbin", "etc/initramfs.d", "dev", "dev/loop", "lib", "newroot", "proc", "sys"]

        # node name -> [type ("c" or "b"), major, minor]
        self.deviceNodes = {"null"      : ["c", 1, 3],
                            "console"   : ["c", 5, 1],
                            "tty"       : ["c", 5, 0],
                            "tty0"      : ["c", 4, 0],
                            "tty1"      : ["c", 4, 1],
                            "ram0"      : ["b", 1, 0],
                            "fb0"       : ["c", 29, 0],
                            "urandom"   : ["c", 1, 9],
                            "kmsg"      : ["c", 1, 11],
                           }

        # source path on the host -> destination inside the image
        self.baseFileList = {
                                 "/bin/busybox"                 : "/bin/",
                                 "/bin/busybox.links"           : "/bin/",
                                 "/usr/bin/disktype"            : "/bin/",
                                 "/usr/bin/sdmem"               : "/bin/",
                                 "/lib/initramfs/init"          : "/",
                                 "/lib/initramfs/hotplug"       : "/sbin/",
                                 "/lib/initramfs/udhcpc.script" : "/etc/",
                                 "/lib/initramfs/profile.rc"    : "/etc/profile",
                            }

        # Every blacklist*conf file of the host goes to /etc/modprobe.d.
        blacklistFiles = glob.glob("/etc/modprobe.d/blacklist*conf")
        self.modprobeBlacklistFileList = dict((path, "/etc/modprobe.d") for path in blacklistFiles)

        self.suspendFileList = {
                                    "/etc/suspend.conf"         : "/etc/",
                                    "/usr/sbin/resume"          : "/bin/",
                               }

    def getNewRoot(self, dir):
        """Return the path inside the temporary image root that matches dir.

        Leading slashes are removed from dir before it is joined to tmpDir,
        so absolute and relative spellings give the same result.
        """
        return os.path.join(self.tmpDir, dir.lstrip("/"))

    def createBaseDirectories(self):
        """Create the directory skeleton of the image."""
        for name in self.baseDirs:
            mkdir(self.getNewRoot(name))

        for name in ("/lib/modules/%s" % self.kver, "/lib/firmware", "/etc/modprobe.d"):
            mkdir(self.getNewRoot(name))

    def copyBasefiles(self):
        """Copy the base files and create the fsck placeholders and fstab."""
        for source, target in self.baseFileList.items():
            copy(source, self.getNewRoot(target))

        # 0o755 equals the legacy literal 0755 and also parses on Python 3.
        for script in ["/init", "/etc/udhcpc.script"]:
            os.chmod(self.getNewRoot(script), 0o755)

        # Empty fsck.<fs> files are created in /bin for these filesystems.
        for fstype in ["ext2", "ext3", "ext4", "reiserfs", "xfs"]:
            touch(self.getNewRoot("/bin/fsck.%s" % fstype))

        writeFile(self.getNewRoot("/etc/fstab"), "none none none defaults 0 0")

    def createConfig(self):
        """Copy the initramfs configuration file into the image, if present."""
        # FIXME: Parse config files and create a proper one by hand
        configFileSource = config["initramfsConf"]
        if os.path.exists(configFileSource):
            copy(configFileSource, self.getNewRoot(configFileSource))

    def createNodes(self):
        """Create the device nodes listed in deviceNodes plus 8 loop devices."""
        for name, (nodtype, major, minor) in self.deviceNodes.items():
            mknod(self.getNewRoot("dev/%s" % name), nodtype, major, minor)

        # dev/loop/N is the real node; dev/loopN is a relative symlink to it.
        for index in range(8):
            mknod(self.getNewRoot("dev/loop/%i" % index), "b", 7, index)
            dosymlink("loop/%i" % index, self.getNewRoot("dev/loop%i" % index))

    def createBaseSymlinks(self):
        """Hard-link busybox to every applet name listed in busybox.links."""
        busybox = self.getNewRoot("/bin/busybox")

        for entry in loadFile(self.getNewRoot("/bin/busybox.links")):
            # FIXME: python dosym does not play nice with cpio, it creates 120MB initramfs
            # Only the last path component is used; every applet goes to /bin.
            dohardlink(busybox, self.getNewRoot("/bin/%s" % entry.split("/")[-1]))

        # FIXME: maybe we should symlink sbin to bin
        dohardlink(busybox, self.getNewRoot("/sbin/modprobe"))

    def _copyStaticBinary(self, staticFile):
        """Copy staticFile into the image without its ".static" suffix, if it exists."""
        if os.path.exists(staticFile):
            copy(staticFile, self.getNewRoot(staticFile.replace(".static", "")))

    def addRaid(self):
        """Install /sbin/mdadm.static as /sbin/mdadm when it is available."""
        self._copyStaticBinary("/sbin/mdadm.static")

    def addLvm(self):
        """Install /sbin/lvm.static as /sbin/lvm when it is available."""
        self._copyStaticBinary("/sbin/lvm.static")

    def addSuspend(self):
        """Copy the files listed in suspendFileList into the image."""
        for source, target in self.suspendFileList.items():
            copy(source, self.getNewRoot(target))

    def addModprobeBlacklists(self):
        """Copy the host's modprobe blacklist files into the image."""
        for source, target in self.modprobeBlacklistFileList.items():
            copy(source, self.getNewRoot(target))

    def create(self):
        """Run every step in order.

        The base files are copied before the busybox links are made, because
        createBaseSymlinks reads busybox.links from inside the image.
        """
        self.createBaseDirectories()
        self.copyBasefiles()
        self.createNodes()
        self.createBaseSymlinks()
        self.createConfig()
        self.addRaid()
        self.addLvm()
        self.addSuspend()
        self.addModprobeBlacklists()

class KernelModule:
    """Selects kernel modules and firmware and installs them into the image.

    Modules are collected in modulesList as full paths below modulesDir,
    either by directory (addDir) or by name (addModule), and are copied into
    the temporary image directory by installModules.
    """

    def __init__(self):
        self.modulesList = []
        self.allModules = []
        self.blackList = config["blackList"]
        self.kernelVersion = config["kernelVersion"]

        self.targetDir = config["tmpDir"]
        self.rootDir = config["rootDir"]
        self.modulesDir = os.path.join(self.rootDir, "lib/modules/%s" % self.kernelVersion)
        self.firmwareDir = os.path.join(self.rootDir, "lib/firmware")

        self.addNetworkModule = config["networkModule"]
        self.addNetworkModuleBasic = config["networkModuleBasic"]

        self.addDRMModules = not config["excludeDRM"]
        self.findAllModules()

        self.scsiDirs = ["kernel/drivers/scsi"]
        self.scsiModules = ["mptfc", "mptsas", "mptscsih", "mptspi", "zfcp"]

        self.mdDirs = ["kernel/drivers/md"]
        self.ataDirs = ["kernel/drivers/ata"]
        self.mmcDirs = ["kernel/drivers/mmc"]
        self.ideDirs = ["kernel/drivers/ide"]
        self.blockDirs = ["kernel/drivers/block"]

        self.firewireModules = ["firewire-ohci", "firewire-sbp2", "firewire-net", "firewire-core"]
        self.i2oModules = ["i2o_block"]
        self.usbModules = ["usb-storage", "sd_mod", "usbcore", "ehci-hcd", "ohci-hcd", "uhci-hcd"]

        self.filesystemModules = ["ext2", "ext3", "ext4", "reiser4", "jfs", "reiserfs", "xfs", "vfat", "fat", "ntfs", "unionfs", "cramfs", "nfs", "nls_utf8", "nls_iso8859_9", "nls_cp857", "nls_iso8859-1", "nls_ascii", "nls_cp850", "squashfs", "aufs", "udf"]

        self.networkBaseModules = ["af_packet", "mii", "8390", "via-rhine", "8139too", "ne2k-pci", "e100", "sky2", "tg3", "skge"]
        self.networkDirs = ["kernel/drivers/net"]

        self.drmDir = "kernel/drivers/gpu/drm"

        self.virtioModules = ["virtio", "virtio_balloon", "virtio_blk", "virtio_net", "virtio_pci"]
        self.xenModules = ["xenblk", "xenfb", "gntdev", "xennet", "xenkbd"]

    def _imageModulesDir(self):
        """Return the modules directory inside the temporary image.

        This is modulesDir with the rootDir prefix replaced by the image root.
        """
        relative = self.modulesDir.replace(self.rootDir, "/").lstrip("/")
        return os.path.join(self.targetDir, relative)

    def tidyModuleList(self):
        """Remove duplicates from modulesList and sort it."""
        self.modulesList = sorted(set(self.modulesList))

    def depmod(self, targetdir):
        """Run depmod for this kernel version with targetdir as base directory."""
        cmd = "/sbin/depmod -a -b %s -F %s/System.map %s"
        capture(cmd % (targetdir, self.modulesDir, self.kernelVersion))

    def installModules(self):
        """Copy every selected, non-blacklisted module into the image.

        All modules go into one flat directory; the blacklist is matched
        against the file name with ".ko" removed.
        """
        destination = self._imageModulesDir()

        for module in self.modulesList:
            if os.path.basename(module).replace(".ko", "") not in self.blackList:
                copy(module, destination)

    def findAllModules(self):
        """Fill allModules with the path of every *.ko file below modulesDir.

        Calls printFail, which exits the program, if modulesDir is missing.
        """
        if not os.path.exists(self.modulesDir):
            printFail("There is no %s, please check your kernel version" % self.modulesDir)

        foundModules = []
        for root, directory, files in os.walk(self.modulesDir):
            foundModules.extend(os.path.join(root, name) for name in files if name.endswith(".ko"))

        self.allModules = foundModules

    def findModuleDeps(self):
        """Append the modules.dep dependencies of the selected modules.

        A modules.dep line belongs to a selected module when it contains the
        module's path relative to modulesDir followed by a colon.
        """
        depdata = loadFile(os.path.join(self.modulesDir, "modules.dep"))

        # The lookup keys do not depend on the line, so build them only once.
        prefix = "%s/" % self.modulesDir
        keys = ["%s:" % module.replace(prefix, "") for module in self.modulesList]

        deps = []
        for line in depdata:
            for key in keys:
                if key in line:
                    deps.extend(line.split(":")[1].strip().split(" "))

        self.modulesList.extend([os.path.join(self.modulesDir, x) for x in deps if not x == ""])

    def appendFirmwares(self):
        """Copy the firmware files requested by the installed modules.

        The names come from "modinfo -F firmware" run on the modules already
        copied into the image.  Missing firmware only produces a warning.
        """
        ret = capture("/sbin/modinfo -F firmware %s/*.ko" % self._imageModulesDir())[0]
        targetbase = os.path.join(self.targetDir, "lib/firmware")

        for name in ret.split("\n"):
            # Empty output would otherwise give one empty name, which resolves
            # to the firmware directory itself and cannot be copied.
            if not name:
                continue

            src = os.path.join(self.firmwareDir, name)
            if not os.path.exists(src):
                printWarn("Could not find firmware %s" % src)
                continue

            # Keep the sub-directory part of the firmware name in the image.
            target = os.path.join(targetbase, os.path.dirname(name))
            if not os.path.exists(target):
                mkdir(target)

            copy(src, target)

    def updateModuleDependencies(self, force=False):
        """Run depmod on rootDir when modules.dep is missing or force is set."""
        # this is here to prevent a race condition with pakhandler, when installing a system from scratch
        if force or not os.path.exists("%s/modules.dep" % self.modulesDir):
            printWarn("Could not find module dependencies in %s, running depmod for system" % self.modulesDir)
            self.depmod(self.rootDir)

    def addDir(self, mdir):
        """Select every known module located below mdir.

        mdir is relative to modulesDir, for example "kernel/drivers/md".
        """
        # The trailing separator keeps a sibling directory that merely shares
        # the prefix (".../md" versus ".../mdx") from being selected as well.
        dirtoadd = os.path.join(self.modulesDir, mdir).rstrip("/") + "/"

        self.modulesList.extend([path for path in self.allModules if path.startswith(dirtoadd)])

    def addModule(self, module):
        """Select the first known module whose path ends in "/<module>.ko".

        module may carry a ".ko" suffix or leading directories.  Dashes and
        underscores are treated as equal.  Nothing happens if no module matches.
        """
        if module.endswith(".ko"):
            module = module[:-3]

        # FIXME: I still don't trust the _ versus - case, try to be sure of it in kernel
        wanted = "/%s.ko" % module.replace("-", "_")

        for line in self.allModules:
            if line.replace("-", "_").endswith(wanted):
                self.modulesList.append(line)
                return

    def addScsi(self):
        """Select the SCSI driver directories and the extra SCSI modules."""
        for i in self.scsiDirs:
            self.addDir(i)

        for i in self.scsiModules:
            self.addModule(i)

    def addAta(self):
        """Select everything in the ATA driver directories."""
        for i in self.ataDirs:
            self.addDir(i)

    def addMmc(self):
        """Select everything in the MMC driver directories."""
        for i in self.mmcDirs:
            self.addDir(i)

    def addBlock(self):
        """Select everything in the block driver directories."""
        for i in self.blockDirs:
            self.addDir(i)

    def addIde(self):
        """Select everything in the IDE driver directories."""
        for i in self.ideDirs:
            self.addDir(i)

    def addNetwork(self):
        """Select network modules according to the two network settings."""
        if self.addNetworkModule:
            for i in self.networkDirs:
                self.addDir(i)

        if self.addNetworkModuleBasic:
            for i in self.networkBaseModules:
                self.addModule(i)

    def addDRM(self):
        """Select DRM modules containing drm_crtc_init, plus the uvesafb fallback.

        Does nothing when DRM modules were excluded in the configuration.
        """
        if self.addDRMModules:
            for i in glob.glob("%s/*/*.ko" % os.path.join(self.modulesDir, self.drmDir)):
                # Check for drm_crtc_init symbol
                # (grep exits with 0, which is falsy, when the word is found)
                if not os.system("grep -qw drm_crtc_init %s" % i):
                    self.addModule(i.partition(self.modulesDir+'/')[-1])

            # Add uvesafb as a fallback
            if os.path.exists("/sbin/v86d"):
                copy("/sbin/v86d", os.path.join(self.targetDir, "sbin/"))
                self.addModule("kernel/drivers/video/uvesafb.ko")

    def addMd(self):
        """Select everything in the MD driver directories."""
        for i in self.mdDirs:
            self.addDir(i)

    def addFirewire(self):
        """Select the FireWire modules."""
        for i in self.firewireModules:
            self.addModule(i)

    def addI2o(self):
        """Select the I2O modules."""
        for i in self.i2oModules:
            self.addModule(i)

    def addUsb(self):
        """Select the USB modules."""
        for i in self.usbModules:
            self.addModule(i)

    def addFilesystem(self):
        """Select the filesystem and NLS modules."""
        for i in self.filesystemModules:
            self.addModule(i)

    def addVirtio(self):
        """Select the virtio modules."""
        for i in self.virtioModules:
            self.addModule(i)

    def addXen(self):
        """Select the Xen modules."""
        for i in self.xenModules:
            self.addModule(i)

    def addGeneric(self):
        """Select the default set of module groups.

        addXen() is not part of this set.
        """
        self.addScsi()
        self.addAta()
        self.addMmc()
        self.addBlock()
        self.addIde()
        self.addNetwork()
        self.addMd()
        self.addFirewire()
        self.addI2o()
        self.addUsb()
        self.addFilesystem()
        self.addVirtio()
        self.addDRM()

    def autoGenerate(self):
        """Select, resolve, install and index the modules for the image."""
        self.updateModuleDependencies()
        self.addGeneric()
        self.findModuleDeps()
        self.tidyModuleList()
        self.installModules()
        self.appendFirmwares()
        self.depmod(self.targetDir)

class Plymouth:
    """Copies the plymouth files and the current plymouth theme into the image."""

    def __init__(self):
        self.fileList = "/lib/initramfs/plymouth.list"
        self.targetDir = config["tmpDir"]
        self.theme = "pisilinux"
        self.themeDir = "/usr/share/plymouth/themes"

    def install_current_theme(self):
        """Copy the entries of the current theme's directory into the image.

        The theme name is taken from the output of plymouth-set-default-theme
        and stored in self.theme.
        """
        pipe = os.popen("plymouth-set-default-theme")
        try:
            self.theme = pipe.read().strip()
        finally:
            pipe.close()

        # No output means no theme.  Without this check the empty name would
        # resolve to themeDir itself and every theme directory would be
        # handed to copy().
        if not self.theme:
            return

        if os.path.exists("%s/%s" % (self.themeDir, self.theme)):
            for themeFile in glob.glob("%s/%s/*" % (self.themeDir, self.theme)):
                copy(themeFile, os.path.join(self.targetDir, themeFile.lstrip("/")))

    def install(self):
        """Will install plymouth related base stuff."""
        if os.path.exists(self.fileList):
            with open(self.fileList, "r") as files:
                for _file in files:
                    filename = _file.strip()
                    # Blank lines name no file to copy.
                    if not filename:
                        continue
                    copy(filename, os.path.join(self.targetDir, filename.lstrip("/")))

        self.install_current_theme()

class Initramfs:
    """Packs the temporary directory into a gzip-compressed cpio archive."""

    def __init__(self):
        self.destDir = config["destDir"]
        self.sourceDir = config["tmpDir"]
        self.initramfs = "initramfs-%s" % config["kernelVersion"]

    def create(self):
        """Write the archive to destDir/initramfs-<kernelVersion>.

        destDir is created first when it does not exist.  The output of the
        shell pipeline is captured and discarded.
        """
        destinationMissing = not os.path.exists(self.destDir)
        if destinationMissing:
            mkdir(self.destDir)

        imagePath = os.path.join(self.destDir, self.initramfs)
        pipeline = "(cd %s && find . | cpio --quiet --dereference -o -H newc | gzip -6 > %s)"
        capture(pipeline % (self.sourceDir, imagePath))

class Tempdir:
    """Owns the temporary working directory in which the image is assembled."""

    def __init__(self):
        self.tmpDir = ""
        self.keepTmp = False

    def create(self):
        """Create a new "mkinitramfs-" temporary directory and remember its path."""
        self.tmpDir = tempfile.mkdtemp(prefix="mkinitramfs-")

    def cleanup(self):
        """Delete the temporary directory, or only warn when keepTmp is set."""
        if not self.keepTmp:
            shutil.rmtree(self.tmpDir)
            return

        printWarn("Keeping temporary directory %s" % self.tmpDir)


# Script entry point: parse the command line, fill in the global config and
# build the image in a temporary directory.
if __name__ == "__main__":
    parser = OptionParser()
    parser.add_option("-k", "--kernel", dest="kernelVersion", type="string",
            help="kernel version to create initramfs for")

    parser.add_option("-t", "--type", dest="type", type="string", default="kernel",
            help="kernel type to create initramfs for")

    parser.add_option("-o", "--output", dest="destDir", type="string", metavar="DIR", default="/boot",
            help="create initramfs in DIR")

    parser.add_option("-c", "--configfile", dest="configFile", type="string", metavar="FILE", default="/etc/initramfs.conf",
            help="use FILE for initramfs config file, default is /etc/initramfs.conf")

    parser.add_option("-r", "--rootdir", dest="rootDir", type="string", metavar="DIR", default="/",
            help="use DIR as basedir for kernel modules")

    # NOTE(review): opts.filename, opts.listModules and opts.listBase are
    # parsed but never read in this block - confirm whether they are still
    # meant to be implemented.
    parser.add_option("-f", "--filename", dest="filename", type="string", metavar="FILE",
            help="use FILE for initramfs file name")

    parser.add_option("-d", "--debug", action="store_true", dest="debug", default=False,
            help="print extra debug info")

    parser.add_option("--blacklist", dest="blackList", type="string", metavar="FILES", default="",
            help="define modules to be blacklisted, seperated by comma. Example: e100,rtl819,ahci")

    parser.add_option("--network", action="store_true", dest="networkModule", default=False,
            help="add network modules")

    parser.add_option("--nodrm", action="store_true", dest="excludeDRM", default=False,
            help="Don't include KMS capable DRM modules")

    parser.add_option("--network-generic", action="store_true", dest="networkModuleBasic", default=False,
            help="add only generic network modules")

    parser.add_option("--keeptmp", action="store_true", dest="keepTmp", default=False,
            help="whether to keep temporary dir after operation")

    parser.add_option("-n", "--dry-run", action="store_true", dest="dryrun", default=False,
            help="do not perform any action, just show what will be done")

    parser.add_option("--list-modules", action="store_true", dest="listModules", default=False,
            help="do not perform any action, just show what will be done")

    parser.add_option("--list-base", action="store_true", dest="listBase", default=False,
            help="do not perform any action, just show what will be done")

    (opts, args) = parser.parse_args()

    # The temporary directory is created only after the options were parsed,
    # so that "--help" or an invalid option (both make parse_args exit) do
    # not leave an unused directory behind.
    tempdir = Tempdir()
    tempdir.create()
    tempdir.keepTmp = opts.keepTmp
    config["tmpDir"] = tempdir.tmpDir

    config["initramfsConf"] = opts.configFile

    config["destDir"] = os.path.abspath(opts.destDir)
    config["rootDir"] = os.path.abspath(opts.rootDir)

    config["debug"] = opts.debug
    config["dryrun"] = opts.dryrun

    config["kernelType"] = opts.type
    config["networkModule"] = opts.networkModule
    config["networkModuleBasic"] = opts.networkModuleBasic
    config["excludeDRM"] = opts.excludeDRM

    # Accept a single module name as well as a comma separated list; empty
    # items (for example from a trailing comma) are ignored.
    extraBlackList = [name.strip() for name in opts.blackList.split(",")]
    config["blackList"].extend([name for name in extraBlackList if name])

    if opts.kernelVersion:
        config["kernelVersion"] = opts.kernelVersion
    else:
        setKernelVersion()

    basesystem = BaseSystem()
    basesystem.create()

    # Check for plymouth and install if found
    plymouth = Plymouth()
    plymouth.install()

    kernelmodule = KernelModule()
    kernelmodule.autoGenerate()

    initramfs = Initramfs()
    initramfs.create()

    tempdir.cleanup()

